import numpy as np

# Small demo of NumPy integer-array ("fancy") indexing: pick whole rows of
# a 3x3 matrix using a plain Python list of row indices.
a = np.array(
    [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ]
)
print(a.shape)
print(a)

# Row indices to select. Indexing a 2-D array with a list of ints along the
# first axis returns a new array holding those rows, in the given order.
b = [0, 2]
print(a[b])  # same result as a[b, :]
